function W = ConsumerSurplus_count(NB, MNL, Xa_mnl, MP, Vec)
% CONSUMERSURPLUS_COUNT Consumer surplus (CS) from a linked count-data (NB)
% and multinomial logit (MNL) model, with delta-method standard errors.
%
% Inputs
%   NB     - count data model results. Fields used: bhat, ihess, DetailsA,
%            INPUT.Xa, EstimOpt.NP. The coefficient on the 2nd column of
%            NB.INPUT.Xa (NB.bhat(2)) is treated as the CS coefficient, so
%            that variable must enter the count model second.
%   MNL    - choice model results. Fields used: bhat, ihess, EstimOpt.NAlt.
%            The travel-cost coefficient must be the LAST element of bhat.
%   Xa_mnl - MNL design matrix with NAlt*NP rows and numel(MNL.bhat)
%            columns; reshape(..., [NAlt, NP]) must give one column per
%            person, i.e. a person's alternatives are consecutive rows.
%            Travel cost is its last column.
%   MP     - scaling factors: MP(1) for MNL-based CS, MP(2) for count-model
%            CS. NOTE(review): presumably a money/unit multiplier - confirm.
%   Vec    - (optional) one entry per alternative assigning it to a group;
%            welfare loss is computed for removing each group in
%            unique(Vec). Empty or omitted -> 1:NAlt (each site separately).
%
% Output W ([estimate, standard error] unless noted)
%   W.CS_count      - per-trip CS from the count model, MP(2)/b2
%   W.CS_tot        - total CS from the count model (mean over persons)
%   W.CS_mnl        - per-trip CS from the MNL logsum (mean over persons)
%   W.CS_test       - NP x 1 per-person MNL CS (no standard error)
%   W.CS_mnl_change - per-trip CS loss for each group (column vector,
%                     no standard error)
%   W.CS_change_tot - numel(unique(Vec)) x 2 total CS loss per group; the
%                     SE only reflects count-model parameter uncertainty.

NAlt = MNL.EstimOpt.NAlt;
if nargin < 5 || isempty(Vec)
    Vec = 1:NAlt; % default: every alternative is its own group
end
Vec = Vec(:);
if numel(Vec) ~= NAlt
    % A shorter logical row index would silently drop trailing alternatives.
    error('ConsumerSurplus_count:VecLength', ...
        'Vec must have one entry per alternative (NAlt = %d), got %d.', ...
        NAlt, numel(Vec));
end

% NOTE(review): no random draws are visible in this function (jacobianest is
% external); the seed is kept because callers may rely on the RNG state.
rng(100001);

%% CS per trip (count model)
CovNB = NB.ihess;              % covariance of NB.bhat for the delta method
bNB = NB.bhat;
NVarA = size(NB.DetailsA, 1);  % number of count-model regressors
XaNB = NB.INPUT.Xa;

% exp(X*B)/B(2) per person; the CS coefficient must enter second.
lamScaled = @(B, X) exp(X*B(1:NVarA))/B(2);

W.CS_count = delta_method(@(B) MP(2)/B(2), MP(2)/NB.DetailsA(2,1), bNB, CovNB);

%% Total CS (count model)
CS_tot = MP(2)*mean(lamScaled(bNB, XaNB), 1);
W.CS_tot = delta_method(@(B) MP(2)*mean(lamScaled(B, XaNB), 1), CS_tot, bNB, CovNB);

%% Taking into account MNL results (CS loss)
CovMNL = MNL.ihess;
bMNL = MNL.bhat;
NP = NB.EstimOpt.NP;

% Orient the MNL scale by the sign of the travel-cost data (last column),
% so that CS = -logsum / (coefficient of a positively coded cost).
MP(1) = -MP(1)*cost_sign(Xa_mnl);

% Per-person logsum (NP x 1); the travel-cost coefficient must enter last.
logsum = @(B) logsumexp_cols(reshape(Xa_mnl*B, [NAlt, NP]))';

CS_mnl = MP(1)*mean(logsum(bMNL)/bMNL(end), 1);
W.CS_mnl = delta_method(@(B) MP(1)*mean(logsum(B)/B(end), 1), CS_mnl, bMNL, CovMNL);

CS = MP(1)*logsum(bMNL)/bMNL(end); % per-person CS
W.CS_test = CS;

Uni = unique(Vec);
NUni = numel(Uni);
V = reshape(Xa_mnl*bMNL, [NAlt, NP]); % utilities, one column per person
CS_loss = zeros(NP, NUni);
for i = 1:NUni
    keep = Vec ~= Uni(i); % alternatives remaining after removing group i
    CS_loss(:,i) = MP(1)*logsumexp_cols(V(keep,:))'/bMNL(end);
end
% Average over persons explicitly (dim 1) so NP == 1 does not average groups.
W.CS_mnl_change = mean(CS - CS_loss, 1)';

%% Total CS loss per group (count model fed with post-removal CS)
W.CS_change_tot = zeros(NUni, 2);
Lam = lamScaled(bNB, XaNB); % baseline, loop-invariant
for i = 1:NUni
    XaLoss = XaNB;
    XaLoss(:,2) = CS_loss(:,i); % replace the CS regressor with its post-removal value
    Lam2 = lamScaled(bNB, XaLoss);
    % NOTE(review): CS_loss is held fixed at MNL.bhat here, so the SE
    % ignores MNL parameter uncertainty.
    W.CS_change_tot(i,:) = delta_method( ...
        @(B) MP(2)*mean(lamScaled(B, XaNB) - lamScaled(B, XaLoss), 1), ...
        MP(2)*mean(Lam - Lam2, 1), bNB, CovNB);
end

end

function out = delta_method(h, est, B, Cov)
% Returns [est, delta-method SE] using the numerical Jacobian of h at B.
H = jacobianest(h, B);
out = [est, sqrt(H*Cov*H')];
end

function s = cost_sign(Xa)
% Sign of the travel-cost data (last column of Xa). Row 2 is used when it is
% present and nonzero; otherwise the first nonzero entry decides.
c = Xa(:, end);
if numel(c) >= 2 && c(2) ~= 0
    s = sign(c(2));
else
    k = find(c ~= 0, 1);
    if isempty(k)
        error('ConsumerSurplus_count:ZeroCost', ...
            'The travel-cost column (last column of Xa_mnl) is all zero.');
    end
    s = sign(c(k));
end
end

function lse = logsumexp_cols(V)
% Column-wise log(sum(exp(V), 1)), shifted by each column's maximum to avoid
% overflow/underflow. An input with no rows gives -Inf, as log(0) would.
if size(V, 1) == 0
    lse = -Inf(1, size(V, 2));
    return
end
m = max(V, [], 1);
m(~isfinite(m)) = 0; % avoid Inf - Inf = NaN; the result is still +/-Inf
lse = m + log(sum(exp(bsxfun(@minus, V, m)), 1));
end


